{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1000, cost:0.002649\n",
      "epoch: 2000, cost:0.000670\n",
      "epoch: 3000, cost:0.000276\n",
      "epoch: 4000, cost:0.000134\n",
      "epoch: 5000, cost:0.000070\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# Bi-LSTM with dot-product attention for 2-class sentence classification (TF1 graph mode).\n",
    "# Builds the graph, trains on six 3-word sentences, then predicts the class of one test sentence.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Training data: one class label per sentence, in the same order.\n",
    "sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n",
    "labels = [1,1,1,0,0,0]\n",
    "\n",
    "# parameters\n",
    "embedding_size = 2  # width of each word embedding vector\n",
    "n_hidden = 5        # units in each LSTM direction\n",
    "n_step = 3          # words per sentence; X is declared as [None, n_step]\n",
    "n_class = 2         # number of output classes\n",
    "\n",
    "# Vocabulary. Sorting the unique words keeps word ids stable across runs\n",
    "# (plain set iteration order is not guaranteed for strings).\n",
    "word_list = sorted(set(' '.join(sentences).split()))\n",
    "word_num = {w:i for i,w in enumerate(word_list)}\n",
    "num_word = {i:w for i,w in enumerate(word_list)}\n",
    "vocab_size = len(word_list)\n",
    "\n",
    "# Inputs are sequences of word ids; targets are one-hot rows of length n_class.\n",
    "input_batch = [[word_num[n] for n in sen.split()] for sen in sentences]\n",
    "target_batch = [np.eye(n_class)[label] for label in labels]\n",
    "\n",
    "X = tf.placeholder(tf.int32,[None,n_step])\n",
    "Y = tf.placeholder(tf.int32,[None,n_class])\n",
    "# Output projection from the attention context (2*n_hidden) to class logits.\n",
    "out = tf.Variable(tf.random_normal([n_hidden*2,n_class]))\n",
    "\n",
    "embedding = tf.Variable(tf.random_uniform([vocab_size,embedding_size]))\n",
    "cell_in = tf.nn.embedding_lookup(embedding,X)  # [batch, n_step, embedding_size]\n",
    "\n",
    "# model\n",
    "fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)\n",
    "bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)\n",
    "\n",
    "hidden_outputs,final_state = tf.nn.bidirectional_dynamic_rnn(fw_cell,bw_cell,cell_in,dtype=tf.float32)\n",
    "# hidden_outputs: (forward, backward), each [batch, n_step, n_hidden]\n",
    "# final_state:    (forward, backward), each an LSTM state (c, h) of [batch, n_hidden]\n",
    "output = tf.concat([hidden_outputs[0],hidden_outputs[1]],2)  # [batch, n_step, 2*n_hidden]\n",
    "# NOTE(review): the attention query below concatenates the backward cell's c and h states only;\n",
    "# the forward cell's final state is unused. Left as-is -- confirm this is intended.\n",
    "final_state = tf.concat([final_state[1][0],final_state[1][1]],1)  # [batch, 2*n_hidden]\n",
    "final_state = tf.expand_dims(final_state,2)  # [batch, 2*n_hidden, 1]\n",
    "# Dot product of every time step's output with the query -> one score per step.\n",
    "attn_weight = tf.squeeze(tf.matmul(output,final_state),2)  # [batch, n_step]\n",
    "soft_attn_weights = tf.nn.softmax(attn_weight, 1)  # normalised over the n_step axis\n",
    "# Attention-weighted sum of the per-step outputs.\n",
    "context = tf.matmul(tf.transpose(output,[0,2,1]),tf.expand_dims(soft_attn_weights,2))  # [batch, 2*n_hidden, 1]\n",
    "context = tf.squeeze(context,2)  # [batch, 2*n_hidden]\n",
    "model = tf.matmul(context,out)  # logits, [batch, n_class]\n",
    "\n",
    "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model,labels=Y))\n",
    "optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)\n",
    "\n",
    "predict = tf.argmax(tf.nn.softmax(model),1)  # predicted class index per row\n",
    "\n",
    "# training\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "for i in range(5000):\n",
    "    _,loss = sess.run([optimizer,cost],feed_dict={X:input_batch,Y:target_batch})\n",
    "\n",
    "    if (i+1) %1000==0:\n",
    "        print('epoch: %d, cost:%.6f' % ((i+1),loss))\n",
    "\n",
    "# Prediction on the test sentence. Feed `tests` (not the training batch) so that\n",
    "# pre[0] is the predicted class of `test`; `predict` does not depend on Y, so only X is fed.\n",
    "test = 'sorry hate you'\n",
    "tests = [np.asarray([word_num[n] for n in test.split()])]\n",
    "pre = sess.run(predict,feed_dict={X:tests})\n",
    "print(pre[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of predictions held in `pre`, i.e. how many rows were fed to X in the sess.run(predict, ...) call.\n",
    "len(pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
